import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

def _as_float_tensor(value):
    """Return ``value`` as a floating-point tensor (``torch.as_tensor`` + float cast if needed)."""
    tensor = torch.as_tensor(value)
    # Integer tensors cannot be averaged with .mean(); promote them to float.
    if not tensor.is_floating_point():
        tensor = tensor.float()
    return tensor


def cal_attn_diff(attn_container, k_layer, k_head, k_lh):
    """Compare attention statistics of the original vs. perturbed prompt.

    Args:
        attn_container: mapping ``image_name -> dict``. Each inner dict should hold
            ``"base_attn"`` (original prompt) and ``"pert_attn"`` (perturbed/random
            prompt), as tensors or anything ``torch.as_tensor`` accepts. Entries
            missing either key are skipped with a printed warning. All tensors must
            share one 2-D shape, treated here as ``[layers, heads]``.
        k_layer: number of layers with the largest ``ori - rand`` difference to return.
        k_head: number of heads (difference averaged over layers) to return.
        k_lh: number of individual (layer, head) cells to return.
        Each ``k`` is clamped to the number of available candidates.

    Returns:
        A 12-tuple::

            (rand_layer_mean, ori_layer_mean, diff_layer,   # [layers]
             rand_value_mean, ori_value_mean, diff,         # [layers, heads]
             topk_layer_values, topk_layer_indices,         # length min(k_layer, layers)
             topk_head_values, topk_head_indices,           # length min(k_head, heads)
             topk_lh_values, topk_lh_indices)               # tensor + list of (layer, head) ints

    Raises:
        ValueError: if no entry has both keys, tensors are not 2-D, or shapes differ.
    """
    # 1) Collect the original ("base") and perturbed ("pert") attention tensors.
    ori_list, rand_list = [], []
    for img_name, attn_dict in attn_container.items():
        # Both keys are required; otherwise skip this image.
        if "base_attn" not in attn_dict or "pert_attn" not in attn_dict:
            print(f"Warning: image {img_name} has no base_attn or pert_attn, continue")
            continue
        ori_list.append(_as_float_tensor(attn_dict["base_attn"]))
        rand_list.append(_as_float_tensor(attn_dict["pert_attn"]))

    # Both lists grow together, so checking one is enough.
    if not ori_list:
        raise ValueError("cant find any `base_attn` or `pert_attn` ")

    first_shape = ori_list[0].shape
    if len(first_shape) != 2:
        raise ValueError(f"attn must be 2-D [layers, heads], get shape {tuple(first_shape)}")
    for attn in ori_list + rand_list:
        if attn.shape != first_shape:
            raise ValueError(f"attn shape is not same, expected {first_shape}, get {attn.shape}")

    # 2) Average over images -> [layers, heads].
    ori_value_mean = torch.stack(ori_list, dim=0).mean(dim=0)
    rand_value_mean = torch.stack(rand_list, dim=0).mean(dim=0)

    # Average over heads -> [layers].
    ori_layer_mean = ori_value_mean.mean(dim=1)
    rand_layer_mean = rand_value_mean.mean(dim=1)
    diff_layer = ori_layer_mean - rand_layer_mean

    diff = ori_value_mean - rand_value_mean   # [layers, heads]
    diff_head_mean = diff.mean(dim=0)         # [heads], averaged over layers

    num_layers, num_heads = diff.shape

    # 3) Top-k selections; clamp k so topk never asks for more items than exist.
    topk_layer_values, topk_layer_indices = torch.topk(diff_layer, min(k_layer, num_layers))
    topk_head_values, topk_head_indices = torch.topk(diff_head_mean, min(k_head, num_heads))

    # Layer-head wise: rank the flattened [layers * heads] grid, then decode
    # each flat index back into (layer, head) with row-major divmod.
    topk_lh_values, topk_lh_flat_indices = torch.topk(diff.reshape(-1), min(k_lh, diff.numel()))
    topk_lh_indices = [divmod(idx, num_heads) for idx in topk_lh_flat_indices.tolist()]

    return (rand_layer_mean, ori_layer_mean, diff_layer,  # layer-wise
            rand_value_mean, ori_value_mean, diff,        # layer x head
            topk_layer_values, topk_layer_indices,
            topk_head_values, topk_head_indices,
            topk_lh_values, topk_lh_indices)

def plot_values(rand_layer_mean, ori_layer_mean, diff_layer, figure_name, save_path,
                tick_step=5, dpi=1200):
    """Save a 1x3 figure of per-layer heatmap strips: ori, rand, and ori - rand.

    Args:
        rand_layer_mean, ori_layer_mean, diff_layer: 1-D torch tensors or
            array-likes, one value per layer. Each is converted independently.
        figure_name: file name of the saved figure.
        save_path: output directory; created if it does not exist.
        tick_step: spacing between labelled layer ticks on the y axis (default 5).
        dpi: resolution passed to ``savefig`` (default 1200).
    """
    os.makedirs(save_path, exist_ok=True)
    full_path = os.path.join(save_path, figure_name)

    def _to_numpy(values):
        # Detach first so tensors that require grad can still be converted.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    rv = _to_numpy(rand_layer_mean)
    ov = _to_numpy(ori_layer_mean)
    df = _to_numpy(diff_layer)

    titles = ['Ori Prompt', 'Rand Prompt', 'Difference (Ori - Rand)']
    data_list = [ov, rv, df]
    cmaps = ['viridis', 'viridis', 'coolwarm']

    fig, axes = plt.subplots(1, 3, figsize=(12, 8))
    try:
        for ax, data, title, cmap in zip(axes, data_list, titles, cmaps):
            # Reshape to [layers, 1] so imshow draws a single vertical strip.
            data_2d = data.reshape(-1, 1)
            n_layers = data_2d.shape[0]
            im = ax.imshow(data_2d, aspect='auto', cmap=cmap)
            ax.set_title(title, fontsize=14)
            ax.set_xlabel('')
            ax.set_ylabel('Layer Index', fontsize=14)
            ax.set_xticks([])

            # Set tick positions and labels together. A FixedFormatter labels
            # ticks by their index in the locator's list, so pairing it with a
            # different locator (e.g. MultipleLocator) shifts labels onto the
            # wrong layers.
            ticks = list(range(0, n_layers, tick_step))
            ax.set_yticks(ticks)
            ax.set_yticklabels([str(t) for t in ticks], fontsize=8)

            cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.ax.tick_params(labelsize=8)

        fig.tight_layout()
        fig.savefig(full_path, dpi=dpi)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"{figure_name} saved in {full_path}")



if __name__ == "__main__":
    # 1. Configuration: input attention dump, output figure location, top-k sizes.
    pt_path = "results/attn_weights_output_1k.pt"
    figure_save_path = "results/figures"
    figure_name = "heatmap_layer_level_attn_weights-1k_layer_wise.png"
    k_layer = 10
    k_head = 12
    k_lh = 5

    # map_location="cpu" lets a dump saved from GPU tensors load on a CPU-only
    # machine. NOTE: torch.load unpickles the file; only load trusted dumps.
    attn_container = torch.load(pt_path, map_location="cpu")

    # 2. Compute mean attention and ori - rand differences.
    (rand_layer_mean, ori_layer_mean, diff_layer,
     rand_value_mean, ori_value_mean, diff,
     topk_layer_values, topk_layer_indices,
     topk_head_values, topk_head_indices,
     topk_lh_values, topk_lh_indices) = cal_attn_diff(attn_container, k_layer, k_head, k_lh)

    # 3. Plot the layer-wise heatmaps.
    plot_values(rand_layer_mean, ori_layer_mean, diff_layer, figure_name, figure_save_path)

    # 4. Report the top-k rankings.
    print(f"\ntop{k_layer}_diff_layer:")
    for value, idx in zip(topk_layer_values, topk_layer_indices):
        print(f"layer {idx.item()}: diff = {value.item():.4f}")

    print(f"\ntop{k_head}_diff_head:")
    for value, idx in zip(topk_head_values, topk_head_indices):
        print(f"head {idx.item()}: diff = {value.item():.4f}")

    print(f"\ntop{k_lh}_diff_layer&head:")
    for value, (layer, head) in zip(topk_lh_values, topk_lh_indices):
        print(f" layer {layer}, head {head}: diff = {value.item():.5f}")